from typer import Typer
from .utils import Config, console
from oh_my_gpt import registry
from typing import Optional
import torch
import lightning as L
from lightning import seed_everything


# Default training configuration template (confection-style INI).
# Section names map to registered factories resolved by `registry.resolve`:
# @datamodules, @models, @trainers, @loggers and @callbacks entries select
# implementations from the project registry; plain keys are their kwargs.
BASE_CONFIG = '''
[training]
compile = true
seed = 42

[datamodule]
@datamodules=language_modeling
dataset_path = 
tokenizer_name_or_path= IDEA-CCNL/Wenzhong2.0-GPT2-3.5B-chinese
batch_size = 32
val_size = 0.01
num_workers=4

[model]
@models=nanogpt.base
lr=4e-4
num_warmup_steps=1000
vocab_size=50257
p_drop=0.1

[trainer]
@trainers=lightning
devices=1
strategy=auto
accelerator=auto
precision=16-mixed
max_epochs=5
limit_train_batches=1.0

[trainer.logger]
@loggers=csv
save_dir=logs

[trainer.*.model_checkpoint]
@callbacks=model_checkpoint
monitor = "val/loss"
dirpath = "./checkpoints"
filename = "{epoch}-{step}-{val/loss:.4f}"

[trainer.*.model_summary]
@callbacks=model_summary
max_depth=2

[trainer.*.rich_progress_bar]
@callbacks=rich_progress_bar


'''

# Backward-compatible alias: the original (misspelled) name is kept so that
# existing callers referencing BASEC_CONFIG keep working.
BASEC_CONFIG = BASE_CONFIG


# Top-level Typer application; commands are registered below via @app.command.
app: Typer = Typer(name='Oh My GPT')

@app.command('init')
def init_config(path: str = './config.cfg'):
    """Write the default configuration template to *path*.

    The template is round-tripped through ``Config`` first, so a malformed
    default fails loudly here rather than later at training time.
    """
    Config().from_str(BASEC_CONFIG).to_disk(path=path)
    

@app.command('train')
def train_model(config_path: Optional[str] = './config.cfg'):
    """Run a full training job from a config file.

    Falls back to the built-in default template when *config_path* is falsy
    (e.g. an empty string); otherwise the config is loaded from disk.
    """
    if not config_path:
        config = Config().from_str(BASEC_CONFIG)
    else:
        config = Config().from_disk(config_path)
    console.log('resolve config')
    # Instantiate the registered components (model, datamodule, trainer, ...)
    # declared by the @-annotated sections of the config.
    resolved = registry.resolve(config=config)
    
    # Seed python/numpy/torch RNGs for reproducible runs.
    seed_everything(resolved['training']['seed'])
    
    model=resolved['model']
    
    
    if resolved['training']['compile']:
        console.log('compile pytorch model')
        # 'reduce-overhead' trades longer compile time for lower per-step cost.
        model = torch.compile(model=model, mode='reduce-overhead')
    
    if torch.cuda.is_available():
        # Permit TF32 matmuls on supporting GPUs: faster with slightly
        # reduced float32 precision.
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True
    
    trainer: L.Trainer  = resolved['trainer']
    datamodule = resolved['datamodule']
    
    trainer.fit(model=model, datamodule=datamodule)
    
    # Path of the best checkpoint tracked by the ModelCheckpoint callback.
    # NOTE(review): this value is never used in the code visible here —
    # confirm whether the function continues past this point (e.g. to log
    # or return the path); otherwise the assignment is dead.
    best_checkpoint_path = trainer.checkpoint_callback.best_model_path